from common.np import *


def smooth_curve(x, window_len=11, beta=2):
    """Smooth a 1-D sequence (e.g. a loss curve) with a normalized Kaiser window.

    The sequence is extended at both ends with mirrored copies of its own
    samples, convolved with the window, and trimmed back to ``len(x)``.

    Reference: http://glowingpython.blogspot.jp/2012/02/convolution-with-numpy.html

    :param x: 1-D sequence of values; it should hold at least ``window_len``
        samples, otherwise the mirrored padding is truncated and the result
        no longer has the same length as ``x``
    :param window_len: number of taps in the smoothing window (default 11)
    :param beta: shape parameter passed to ``np.kaiser`` (default 2)
    :return: smoothed 1-D array
    :raises ValueError: if ``window_len`` is smaller than 1
    """
    if window_len < 1:
        raise ValueError("window_len must be at least 1")

    # Mirror window_len - 1 samples onto each end so the 'valid' convolution
    # has data to work with right up to the original boundaries.
    s = np.r_[x[window_len-1:0:-1], x, x[-1:-window_len:-1]]
    w = np.kaiser(window_len, beta)
    y = np.convolve(w/w.sum(), s, mode='valid')

    # 'valid' leaves window_len - 1 surplus samples; drop them symmetrically
    # (the extra one goes to the tail when window_len is even).
    lead = (window_len - 1) // 2
    trail = (window_len - 1) - lead
    return y[lead:len(y) - trail]


def shuffle_dataset(x, t):
    """Shuffle a dataset along its first axis, keeping samples and labels paired.

    A single random permutation of ``range(x.shape[0])`` is drawn with
    ``np.random.permutation`` and applied to both arrays. The inputs are not
    modified; indexing with the permutation array produces new arrays.

    :param x: training data; any number of dimensions, samples along axis 0
    :param t: supervision (label) data, indexed along axis 0 by the same samples
    :return: x, t reordered by the same permutation
    """
    permutation = np.random.permutation(x.shape[0])
    # Indexing only the first axis keeps every trailing axis intact, so this
    # works for any rank of x rather than just the 2-D and 4-D cases.
    x = x[permutation]
    t = t[permutation]
    return x, t


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Unfold a 4-D image batch into a 2-D matrix of filter-sized patches.

    Each output row holds one receptive field, flattened in
    (channel, filter row, filter column) order; rows are ordered by
    (sample, output row, output column).

    :param input_data: 4-D input array, unpacked as (N, C, H, W)
    :param filter_h: filter height
    :param filter_w: filter width
    :param stride: step between neighbouring filter positions
    :param pad: number of zeros added on every side of the two spatial axes
    :return: 2-D array of shape (N * out_h * out_w, C * filter_h * filter_w),
        allocated with the default dtype of ``np.zeros``
    """
    batch, channels, height, width = input_data.shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Zero-pad only the two spatial axes.
    spatial_pad = (pad, pad)
    padded = np.pad(input_data, [(0, 0), (0, 0), spatial_pad, spatial_pad], "constant")

    patches = np.zeros((batch, channels, filter_h, filter_w, out_h, out_w))
    y_extent = stride * out_h
    x_extent = stride * out_w

    # For each offset inside the filter, gather that pixel of every window
    # with a single strided slice.
    for dy in range(filter_h):
        rows = slice(dy, dy + y_extent, stride)
        for dx in range(filter_w):
            cols = slice(dx, dx + x_extent, stride)
            patches[:, :, dy, dx, :, :] = padded[:, :, rows, cols]

    # Bring the output-position axes forward so every row is one patch.
    by_position = patches.transpose(0, 4, 5, 1, 2, 3)
    return by_position.reshape(batch * out_h * out_w, -1)


def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """
    Fold a 2-D patch matrix back into a 4-D image batch.

    Entries of overlapping patches that land on the same pixel are summed.

    :param col: 2-D patch matrix, reshaped internally to
        (N, out_h, out_w, C, filter_h, filter_w)
    :param input_shape: shape of the image batch to rebuild, (N, C, H, W)
        (e.g. (10, 1, 28, 28))
    :param filter_h: filter height
    :param filter_w: filter width
    :param stride: step between neighbouring filter positions
    :param pad: padding width on each spatial side; cropped from the result
    :return: 4-D array of shape (N, C, H, W)
    """
    batch, channels, height, width = input_shape
    out_h = (height + 2 * pad - filter_h) // stride + 1
    out_w = (width + 2 * pad - filter_w) // stride + 1

    # Restore the six patch axes, then move the filter offsets ahead of the
    # output positions: (N, C, filter_h, filter_w, out_h, out_w).
    patches = col.reshape(batch, out_h, out_w, channels, filter_h, filter_w)
    patches = patches.transpose(0, 3, 4, 5, 1, 2)

    # The working buffer is the padded image plus stride - 1 extra rows and
    # columns of slack.
    padded_h = height + 2 * pad + stride - 1
    padded_w = width + 2 * pad + stride - 1
    canvas = np.zeros((batch, channels, padded_h, padded_w))

    y_extent = stride * out_h
    x_extent = stride * out_w
    for dy in range(filter_h):
        rows = slice(dy, dy + y_extent, stride)
        for dx in range(filter_w):
            cols = slice(dx, dx + x_extent, stride)
            # Accumulate, since different windows may cover the same pixel.
            canvas[:, :, rows, cols] += patches[:, :, dy, dx, :, :]

    # Cut away the padding border (and the slack) to get back to H x W.
    return canvas[:, :, pad:height + pad, pad:width + pad]

